# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import typing
import unittest
from contextlib import contextmanager
from copy import deepcopy
from typing import List, Optional, Tuple

import executorch.exir as exir

import executorch.exir.schema as schema
import executorch.exir.tests.models as models
import pytest
import torch
from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    ExecutorchProgramManager,
    to_edge,
)
from executorch.exir._serialize._program import deserialize_pte_binary
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.test.demos.rpc.executor_backend_partitioner import (
    ExecutorBackendPartitioner,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.emit import emit_program  # noqa
from executorch.exir.error import InternalError
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from executorch.exir.print_program import pretty_print, print_program  # noqa
from executorch.exir.schema import (
    Bool,
    DelegateCall,
    Double,
    EValue,
    ExecutionPlan,
    Int,
    IntList,
    JumpFalseCall,
    KernelCall,
    KernelTypes,
    MoveCall,
    Null,
    OptionalTensorList,
    Program,
    String,
    Tensor,
)
from executorch.exir.tests.common import register_additional_test_aten_ops
from executorch.exir.tests.models import Mul
from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)
from executorch.runtime import Runtime

from functorch.experimental import control_flow
from torch import nn

from torch.export import Dim, export
from torch.export.experimental import _export_forward_backward


class WrapperModule(torch.nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


@contextmanager
def patch_forward(obj: torch.nn.Module, new_method):
    """Helper method to make it easier to cleanly torch.export() a method on a
    module that is not `forward`.

    TODO(suo): upstream this to torch.export.wrapper.
    """
    # Save the original method
    original_method = obj.forward

    # Patch the method
    obj.forward = new_method.__get__(obj, obj.__class__)

    try:
        yield
    finally:
        # Restore the original method
        obj.forward = original_method


class TestEmit(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        register_additional_test_aten_ops()

    def setUp(self) -> None:
        self.compile_config = EdgeCompileConfig(_check_ir_validity=False)

    def check_tensor_buffer_loc(
        self,
        value_index: int,
        values: List[EValue],
        exp_buffer_idx: int,
        exp_mem_id: Optional[int],
        exp_mem_offset: Optional[int],
    ) -> None:
        value = typing.cast(schema.Tensor, values[value_index].val)
        self.assertIsInstance(value, schema.Tensor)

        self.assertEqual(value.data_buffer_idx, exp_buffer_idx)

        if not value.allocation_info:
            self.assertIsNone(exp_mem_id)
            self.assertIsNone(exp_mem_offset)
        else:
            self.assertEqual(value.allocation_info.memory_id, exp_mem_id)
            assert value.allocation_info
            self.assertEqual(value.allocation_info.memory_offset, exp_mem_offset)

    def count_node(self, graph_module: torch.fx.GraphModule, opname: str) -> int:
        return [
            node.target._overloadpacket._qualified_op_name
            for node in graph_module.graph.nodes
            if node.op == "call_function"
        ].count(opname)

    def run_dce(self, graph_module: torch.fx.GraphModule) -> None:
        for submodule in graph_module.modules():
            self.assertIsInstance(submodule, torch.fx.GraphModule)
            typing.cast(torch.fx.GraphModule, submodule).graph.eliminate_dead_code()

    def check_value_types(self, values: List[EValue]) -> None:
        for value in values:
            self.assertTrue(type(value.val) in KernelTypes.__args__)

    def count_move_instructions(self, program: Program) -> int:
        instructions = program.execution_plan[0].chains[0].instructions
        assert instructions is not None
        res = 0
        for instr in instructions:
            if isinstance(instr.instr_args, MoveCall):
                res += 1
        return res

    def test_basic_api(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x * y + x

        f = Foo()

        program = (
            to_edge(export(f, (torch.ones(3, 2), torch.zeros(3, 2)), strict=True))
            .to_executorch()
            .executorch_program
        )
        exec_plan = program.execution_plan[0]
        ops = exec_plan.operators
        for op in ops:
            self.assertEqual(op.overload, "out")

        self.assertEqual(ops[0].name, "aten::mul")
        self.assertEqual(ops[1].name, "aten::add")

        self.assertEqual(len(exec_plan.inputs), 2)
        self.assertEqual(len(exec_plan.outputs), 1)

        self.assertEqual(exec_plan.inputs[0], 0)
        self.assertEqual(exec_plan.outputs[0], 3)

    def test_basic_end_to_end(self) -> None:
        f = models.BasicSinMax()
        program = (
            to_edge(export(f, f.get_random_inputs(), strict=True))
            .to_executorch()
            .executorch_program
        )
        exec_plan = program.execution_plan[0]
        ops = exec_plan.operators
        for op in ops:
            self.assertIn(op.overload, {"out", "unary_out"})

        self.assertEqual(ops[0].name, "aten::sin")

        self.assertEqual(len(exec_plan.inputs), 1)
        self.assertEqual(len(exec_plan.outputs), 1)

        self.assertEqual(exec_plan.inputs[0], 0)
        self.assertEqual(exec_plan.outputs[0], 1)

    @pytest.mark.skip(reason="Test not working on OSS")
    def test_nested_return(self) -> None:
        class Foo(torch.nn.Module):
            def forward(
                self, x: torch.Tensor
            ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
                return (
                    torch.tensor(1),
                    torch.tensor(2),
                    [torch.sin(x).max(), torch.cos(x).max()],
                )

        f = Foo()

        x = (torch.randn(100),)
        program = to_edge(export(f, x, strict=True)).to_executorch().executorch_program
        exec_plan = program.execution_plan[0]
        self.assertEqual(len(exec_plan.outputs), 4)
        self.assertEqual(len(exec_plan.inputs), 1)

        self.assertEqual(
            program.execution_plan[0].container_meta_type.encoded_out_str,
            "T3#1#1#2($,$,L2#1#1($,$))",
        )

        self.assertEqual(
            program.execution_plan[0].container_meta_type.encoded_inp_str,
            "T2#1#0(T1#1($),D0())",
        )

    def test_constant_output(self):
        class M(torch.nn.Module):
            def forward(self, x):
                return [((1, 3, 1.2), True, [x + x, x * x], None)]

        ep = torch.export.export(M(), (torch.ones(2, 3),), strict=True)
        res = ep.module()(torch.ones(2, 3))
        self.assertEqual(res[0][0], (1, 3, 1.2))
        program = to_edge(ep).to_executorch().executorch_program
        outputs = program.execution_plan[0].outputs
        self.assertEqual(len(outputs), 7)
        self.assertEqual(program.execution_plan[0].values[outputs[0]].val.int_val, 1)
        self.assertEqual(program.execution_plan[0].values[outputs[1]].val.int_val, 3)
        self.assertEqual(
            program.execution_plan[0].values[outputs[2]].val.double_val, 1.2
        )
        self.assertEqual(
            program.execution_plan[0].values[outputs[3]].val.bool_val, True
        )
        self.assertIsInstance(program.execution_plan[0].values[outputs[6]].val, Null)

    def test_initialized_mutable_buffer(self):
        """Test that mutable buffers can hold meaningful initialized state."""

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                # Mutable buffer with non-empty initial state.
                self.register_buffer("cache_pos", torch.arange(0, 10))

            def forward(self, x):
                self.cache_pos.add_(1)
                return self.cache_pos

        m = TestModule()
        example_inputs = (torch.ones(10),)
        ep = torch.export.export(m, example_inputs, strict=True)
        edge = to_edge(
            ep,
            compile_config=EdgeCompileConfig(
                _check_ir_validity=False,
            ),
        )

        # Save a copy of the edge program since to_executorch is
        # stateful to some degree.
        edge_copy = deepcopy(edge)
        et_config = ExecutorchBackendConfig(
            passes=[InitializedMutableBufferPass(["cache_pos"])],
        )
        et_program_init_pass = edge.to_executorch(config=et_config)
        et_program_regular = edge_copy.to_executorch()

        runtime = Runtime.get()
        program_init_pass = runtime.load_program(et_program_init_pass.buffer)
        method_init_pass = program_init_pass.load_method("forward")

        program_regular = runtime.load_program(et_program_regular.buffer)
        method_regular = program_regular.load_method("forward")

        # Test that the mutable buffer is initialized.
        torch.allclose(
            method_init_pass.execute((example_inputs))[0], torch.arange(1, 11)
        )
        # Test that the mutable buffer is uninitialized and starts with default zeros,
        # we test equality with torch.ones because of the mutation += 1 in the model forward.
        torch.allclose(
            method_regular.execute((example_inputs))[0],
            torch.ones(10, dtype=torch.int64),
        )

    def test_int_list_input(self):
        class M(torch.nn.Module):
            def forward(self, x, y, z):
                return x + y, x + x, x + y + z

        ep = torch.export.export(M(), (torch.ones(2, 3), 2, True), strict=True)
        ep.module()(torch.ones(2, 3), 2, True)
        program = to_edge(ep).to_executorch().executorch_program
        inputs = program.execution_plan[0].inputs
        self.assertEqual(len(inputs), 3)
        self.assertEqual(program.execution_plan[0].values[inputs[1]].val.int_val, 2)
        self.assertEqual(program.execution_plan[0].values[inputs[2]].val.bool_val, True)

    def test_inplace_ops(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                y = torch.sin(x)
                z = y.view(100)
                torch.relu_(z)
                return z.max()

        f = Foo()

        inputs = (torch.ones((10, 10)),)
        edge = to_edge(export(f, inputs, strict=True))

        removed_ops = ["aten::relu_", "aten::view"]
        expected_ops = [
            "aten::sin",
            "aten::relu",
            "aten::max",
        ]

        def expected_view_ops(config):
            if config.remove_view_copy:
                return []
            else:
                return ["aten::view_copy"]

        for opname in removed_ops:
            self.assertEqual(
                self.count_node(edge.exported_program().graph_module, opname), 0
            )
        for opname in expected_ops:
            self.assertTrue(
                self.count_node(edge.exported_program().graph_module, opname) >= 1
            )

        for remove_view_copy in [True, False]:
            config = exir.ExecutorchBackendConfig(remove_view_copy=remove_view_copy)
            edge_copy = deepcopy(edge)
            program = edge_copy.to_executorch(config=config).executorch_program
            for opname in removed_ops:
                self.assertTrue(
                    all(op.name != opname for op in program.execution_plan[0].operators)
                )
            for opname in expected_ops + expected_view_ops(config):
                self.assertTrue(
                    any(op.name == opname for op in program.execution_plan[0].operators)
                )
            self.assertTrue(
                len(program.execution_plan[0].operators)
                == len(expected_ops + expected_view_ops(config))
            )

    def test_operators_unique(self) -> None:
        class OpRepeatedModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(2, 2)
                self.b = 2 * torch.ones(2, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                for _ in range(10):
                    z = self.a * x
                    y = z + self.b
                return y

        model = OpRepeatedModule()

        inputs = (torch.ones(2, 2),)

        program = (
            to_edge(export(model, inputs, strict=True))
            .to_executorch()
            .executorch_program
        )

        self.assertEqual(len(program.execution_plan[0].operators), 2)

    def test_list_type(self) -> None:
        """Tests that the types of lists are correctly found"""

        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.permute(x, (2, 0, 1))

        f = Foo()

        program = (
            to_edge(export(f, (torch.randn(2, 3, 5),), strict=True))
            .to_executorch()
            .executorch_program
        )
        exir.print_program.pretty_print(program)

        deboxed_int_list = []
        for item in program.execution_plan[0].values[5].val.items:
            deboxed_int_list.append(program.execution_plan[0].values[item].val.int_val)

        self.assertEqual(IntList(deboxed_int_list), IntList([2, 0, 1]))

    def test_kwargs1(self) -> None:
        """Tests that the kwargs are placed in the order specified by
        native_functions.yaml
        """

        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                batch1 = torch.randn(10, 3, 4)
                batch2 = torch.randn(10, 4, 5)
                return torch.addbmm(x, batch1, batch2, alpha=2, beta=3)

        f = Foo()

        program = (
            to_edge(export(f, (torch.randn(3, 5),), strict=True))
            .to_executorch()
            .executorch_program
        )
        # The value for beta should appear before alpha
        self.assertEqual(program.execution_plan[0].values[12].val, Int(3))
        self.assertEqual(program.execution_plan[0].values[13].val, Int(2))

    def test_kwargs2(self) -> None:
        """Tests that the kwargs are placed in the order specified by
        native_functions.yaml
        """

        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                values = torch.randn(3, 2)
                return torch.searchsorted(x, values, side="right", right=True)

        f = Foo()

        x, _ = torch.sort(torch.randn(3, 4))
        program = (
            to_edge(export(f, (x,), strict=True)).to_executorch().executorch_program
        )
        # The value for right should appear before side
        self.assertEqual(program.execution_plan[0].values[6].val, Bool(False))
        self.assertEqual(program.execution_plan[0].values[7].val, Bool(True))
        self.assertEqual(program.execution_plan[0].values[8].val, String("right"))
        self.assertEqual(program.execution_plan[0].values[9].val, Null())

    def _assertCallLength(self, program: Program, idx: int, expected_len: int) -> None:
        instr_args = program.execution_plan[0].chains[0].instructions[idx].instr_args

        if isinstance(instr_args, KernelCall) or isinstance(instr_args, DelegateCall):
            self.assertEqual(len(instr_args.args), expected_len)
        else:
            self.assertTrue(False)

    def test_out(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                z = y.clone()
                return torch.mul(x, y, out=z)

        f = Foo()

        program = (
            to_edge(export(f, (torch.ones(3), torch.ones(3)), strict=True))
            .to_executorch()
            .executorch_program
        )

        self.assertEqual(len(program.execution_plan[0].chains[0].instructions), 1)
        self._assertCallLength(program, 0, 4)

    def test_model_out(self) -> None:
        class Module_out(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 3 * torch.ones(2, 2, dtype=torch.int32)
                self.b = 2 * torch.ones(2, 2, dtype=torch.int32)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                z = x.clone()
                torch.mul(self.a, x, out=z)
                y = x.clone()
                torch.add(z, self.b, alpha=2, out=y)
                return y

        model_out = Module_out()

        inputs = (torch.ones(2, 2, dtype=torch.int32),)

        # Trace to FX Graph.
        program = (
            to_edge(export(model_out, inputs, strict=True))
            .to_executorch()
            .executorch_program
        )

        self.assertEqual(len(program.execution_plan[0].chains[0].instructions), 2)
        self._assertCallLength(program, 0, 4)
        self._assertCallLength(program, 1, 5)

    def test_stacktrace(self) -> None:
        def f(x: torch.Tensor) -> torch.Tensor:
            return torch.mul(x, torch.randn(3, 2))

        def g(x: torch.Tensor) -> torch.Tensor:
            return torch.sin(f(x))

        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.add(g(x), torch.randn(3, 2))

        h = Foo()

        x = (torch.randn(3, 2),)
        exec_prog = to_edge(export(h, x, strict=True)).to_executorch(
            exir.ExecutorchBackendConfig(emit_stacktrace=True)
        )
        program = exec_prog.executorch_program

        # Check the mul operator's stack trace contains f -> g -> h
        self.assertTrue(
            "return torch.mul(x, torch.randn(3, 2))"
            in program.execution_plan[0].chains[0].stacktrace[1].items[-1].context
        )
        self.assertEqual(
            program.execution_plan[0].chains[0].stacktrace[1].items[-1].name, "f"
        )
        self.assertEqual(
            program.execution_plan[0].chains[0].stacktrace[1].items[-2].name, "g"
        )
        self.assertEqual(
            program.execution_plan[0].chains[0].stacktrace[1].items[-3].name, "forward"
        )

        # Check the sin operator's stack trace contains g -> h
        self.assertEqual(
            program.execution_plan[0].chains[0].stacktrace[2].items[-1].name, "g"
        )
        self.assertEqual(
            program.execution_plan[0].chains[0].stacktrace[2].items[-2].name, "forward"
        )

    def test_stacktrace_off(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.mul(x, torch.randn(3, 2))

        f = Foo()

        class Goo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.sin(f(x))

        g = Goo()

        class Hoo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.add(g(x), torch.randn(3, 2))

        h = Hoo()

        x = (torch.randn(3, 2),)
        program = to_edge(export(h, x, strict=True)).to_executorch().executorch_program

        # Check the stacktrace is None since we did not specify to get the stacktrace
        self.assertTrue(program.execution_plan[0].chains[0].stacktrace is None)

    def test_positional_argument_default_value(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
                z = torch.ones(6, 2)
                return torch.ops.aten.cat.out((x, n), out=z)

        f = Foo()

        x = torch.randn(3, 2)
        program = (
            to_edge(export(f, (x, x), strict=True))
            # .to_edge(self.compile_config)  # TODO(larryliu): fix cat
            .to_executorch().executorch_program
        )

        self.assertEqual(len(program.execution_plan[0].chains[0].instructions), 1)
        self._assertCallLength(program, 0, 4)

    @pytest.mark.skip(reason="Test not working on OSS")
    def test_emit_multiple_out(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
                return torch.topk(x, 5)

        f = Foo()

        x = (torch.randn(10),)
        program = to_edge(export(f, x, strict=True)).to_executorch().executorch_program
        self._assertCallLength(program, 0, 8)

    def test_emit_layout(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.ones_like(x)

        f = Foo()

        x = (torch.randn(3, 2),)
        program = to_edge(export(f, x, strict=True)).to_executorch().executorch_program

        vals = program.execution_plan[0].values
        for val in vals:
            v = val.val
            if isinstance(v, Tensor):
                self.assertEqual(v.layout, 0)

    def test_optional_tensor_list(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                a = torch.nonzero(x)
                torch._constrain_as_size(a.shape[0], min=1)
                b = torch.ops.aten.index.Tensor(x, [a])
                return b

        f = Foo()
        x = (torch.triu(torch.ones(2, 2)),)
        program = (
            to_edge(
                export(f, x, strict=True),
                compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
            )
            .to_executorch()
            .executorch_program
        )
        self.assertTrue(
            isinstance(program.execution_plan[0].values[3].val, OptionalTensorList)
        )
        self._assertCallLength(program, 0, 3)
        self._assertCallLength(program, 1, 4)

    def test_optional_float_list(self) -> None:
        class M(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.interpolate(x, scale_factor=2)

        x = (torch.randn(1, 1, 2, 2, 2),)
        program = (
            to_edge(export(M(), x, strict=True)).to_executorch().executorch_program
        )
        self.assertIsInstance(
            program.execution_plan[0].values[-1].val, schema.OptionalTensorList
        )

    def test_emit_cond(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, pred, x):
                def true_fn(y: torch.Tensor) -> torch.Tensor:
                    y = y + y
                    y = torch.mm(y, y)
                    return y

                def false_fn(y: torch.Tensor) -> torch.Tensor:
                    return torch.mm(y, y)

                ret = control_flow.cond(pred, true_fn, false_fn, [x])
                return ret

        module = to_edge(
            export(M(), (torch.tensor(True), torch.ones(2, 2)), strict=True)
        )
        program = module.to_executorch().executorch_program

        num_mm = 0
        num_add = 0
        num_other = 0
        for inst in program.execution_plan[0].chains[0].instructions:
            if not isinstance(inst.instr_args, KernelCall):
                continue

            op = program.execution_plan[0].operators[inst.instr_args.op_index].name

            if "mm" in op:
                num_mm += 1
            elif "add" in op:
                num_add += 1
            else:
                num_other += 1

        self.assertEqual(num_mm, 2)
        self.assertEqual(num_add, 1)
        self.assertEqual(num_other, 0)

    def test_emit_map(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                    return x + y

                return control_flow.map(map_fn, x, y)

        f = Foo()

        inputs = (torch.ones(4, 4), torch.ones(4))
        module = to_edge(
            export(f, inputs, strict=True),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        program = module.to_executorch().executorch_program

        op_table = program.execution_plan[0].operators
        # The first two operators at the beginning of a map program should be sym_size
        # and select_copy, which is what we verify here. The first operator is to generate
        # the number of iterations and the second operator is to slice the input tensor to
        # generate the tensor on which this iteration will operate on.
        self.assertEqual(
            op_table[
                program.execution_plan[0].chains[0].instructions[0].instr_args.op_index
            ].name,
            "aten::sym_size",
        )
        self.assertEqual(
            op_table[
                program.execution_plan[0].chains[0].instructions[1].instr_args.op_index
            ].name,
            "aten::select_copy",
        )

        # The last three instructions in the map sub-program are:
        # - Calling the custom op to append the output of this iteration to the accumulator tensor
        # - Increment the iteration count.
        # - Then checking if we've completed all the iterations.
        # We check here that both of these have been generated.
        self.assertEqual(
            op_table[
                program.execution_plan[0].chains[0].instructions[-5].instr_args.op_index
            ].name,
            "executorch_prim::et_copy_index",
        )
        self.assertEqual(
            op_table[
                program.execution_plan[0].chains[0].instructions[-4].instr_args.op_index
            ].name,
            "executorch_prim::add",
        )
        self.assertEqual(
            op_table[
                program.execution_plan[0].chains[0].instructions[-3].instr_args.op_index
            ].name,
            "executorch_prim::eq",
        )
        # The last two instructions in the overall program check if we should jump back to the
        # beginning of the loop and then resets the iteration counter if we fall through.
        self.assertTrue(
            isinstance(
                program.execution_plan[0].chains[0].instructions[-2].instr_args,
                JumpFalseCall,
            )
        )
        self.assertEqual(
            op_table[
                program.execution_plan[0].chains[0].instructions[-1].instr_args.op_index
            ].name,
            "executorch_prim::sub",
        )

    def test_load_emit_map(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                    return x + y

                return control_flow.map(map_fn, x, y)

        f = Foo()

        inputs = (torch.ones(4, 4), torch.ones(4))
        module = to_edge(
            export(f, inputs, strict=True),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        _load_for_executorch_from_buffer(module.to_executorch().buffer)

    def test_run_emit_map(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                    return x + y

                return control_flow.map(map_fn, x, y)

        f = Foo()

        inputs = (torch.ones(4, 4), torch.ones(4))
        module = to_edge(
            export(f, inputs, strict=True),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        buffer = module.to_executorch().buffer
        loaded_model = _load_for_executorch_from_buffer(buffer)
        outputs = loaded_model(inputs)[0]
        torch.allclose(outputs, f(*inputs))

    def test_dim_order(self) -> None:
        class SimpleLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.relu(self.linear(x))

        model = SimpleLinear()
        inputs = (torch.ones(10, 5),)
        program = (
            to_edge(
                export(model, inputs, strict=True),
                compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
            )
            .to_executorch()
            .executorch_program
        )

        addmm_found = False
        for inst in program.execution_plan[0].chains[0].instructions:
            kernel = inst.instr_args
            if isinstance(kernel, KernelCall):
                op_id = kernel.op_index
                op = program.execution_plan[0].operators[op_id]
                if op.name == "aten::addmm":
                    addmm_found = True
                    args = kernel.args
                    bias_id = args[0]
                    act_id = args[1]
                    weight_id = args[2]
                    bias_dim_order = [0]
                    act_dim_order = [0, 1]
                    weight_dim_order = [0, 1]
                    bias_tensor = typing.cast(
                        schema.Tensor, program.execution_plan[0].values[bias_id].val
                    )
                    act_tensor = typing.cast(
                        schema.Tensor, program.execution_plan[0].values[act_id].val
                    )
                    weight_tensor = typing.cast(
                        schema.Tensor, program.execution_plan[0].values[weight_id].val
                    )
                    self.assertTrue(bias_tensor.dim_order == bias_dim_order)
                    self.assertTrue(act_tensor.dim_order == act_dim_order)
                    self.assertTrue(weight_tensor.dim_order == weight_dim_order)
        self.assertTrue(addmm_found)

    def test_non_const_buffer_sizes(self) -> None:
        class Add(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                b = 3 + 1
                return x + torch.tensor(b)

        f = Add()

        edge_program_manager = to_edge(export(f, (torch.ones(3, 2),), strict=True))
        edge_program_manager._edge_programs["forward"] = constant_prop_pass(
            edge_program_manager.exported_program()
        )
        non_const_buffer_size_with_const_prop_pass = (
            edge_program_manager.to_executorch()
            .executorch_program.execution_plan[0]
            .non_const_buffer_sizes
        )

        edge_program_manager = to_edge(export(f, (torch.ones(3, 2),), strict=True))
        non_const_buffer_size_without_const_prop_pass = (
            edge_program_manager.to_executorch()
            .executorch_program.execution_plan[0]
            .non_const_buffer_sizes
        )
        self.assertTrue(
            non_const_buffer_size_with_const_prop_pass[1]
            < non_const_buffer_size_without_const_prop_pass[1]
        )

    # cant compare plans directly with __eq__ because of the plan names, and data_buffer_idx in tensor values
    def _compare_execution_plans(
        self, plan_single: ExecutionPlan, plan_merged: ExecutionPlan
    ) -> None:
        self.assertEqual(
            plan_single.container_meta_type,
            plan_merged.container_meta_type,
        )
        self.assertEqual(
            plan_single.inputs,
            plan_merged.inputs,
        )
        self.assertEqual(
            plan_single.outputs,
            plan_merged.outputs,
        )
        self.assertEqual(
            plan_single.chains,
            plan_merged.chains,
        )
        self.assertEqual(
            plan_single.operators,
            plan_merged.operators,
        )
        self.assertEqual(
            plan_single.non_const_buffer_sizes,
            plan_merged.non_const_buffer_sizes,
        )
        self.assertEqual(
            len(plan_single.values),
            len(plan_merged.values),
        )
        for i in range(0, len(plan_single.values)):
            single_val = plan_single.values[i].val
            merged_val = plan_merged.values[i].val
            if isinstance(single_val, Tensor):
                # constant buffer index might be different as the constant buffer is shared between plans
                self.assertTrue(isinstance(merged_val, Tensor))
                self.assertEqual(single_val.storage_offset, merged_val.storage_offset)
                self.assertEqual(single_val.scalar_type, merged_val.scalar_type)
                self.assertEqual(single_val.sizes, merged_val.sizes)
                self.assertEqual(single_val.dim_order, merged_val.dim_order)
                self.assertEqual(single_val.requires_grad, merged_val.requires_grad)
                self.assertEqual(single_val.layout, merged_val.layout)
                self.assertEqual(single_val.allocation_info, merged_val.allocation_info)
                self.assertEqual(single_val.shape_dynamism, merged_val.shape_dynamism)
            else:
                self.assertEqual(single_val, merged_val)

    def test_emit_memory_format_valid(self) -> None:
        class SimpleLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                contiguous = x.to(
                    dtype=torch.float32, memory_format=torch.contiguous_format
                )
                preserve = x.to(
                    dtype=torch.float32, memory_format=torch.preserve_format
                )
                return contiguous + preserve

        # Should succeed at exporting model with legal memory format (contiguous, preserve)
        model = SimpleLinear()
        inputs = (torch.ones(10, 5),)
        try:
            to_edge(
                export(model, inputs, strict=True),
                compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
            ).to_executorch()
        except:
            self.fail("Failed to export model with legal memory format")

    def test_emit_memory_format_invalid(self) -> None:
        class SimpleLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x.to(dtype=torch.float32, memory_format=torch.channels_last)

        # Failure expected when exporting model with illegal memory format (channels_last) when not using dim_order
        model = SimpleLinear()
        inputs = (torch.ones(10, 5, 2, 1),)
        with self.assertRaises(InternalError):
            to_edge(
                export(model, inputs, strict=True),
                compile_config=exir.EdgeCompileConfig(
                    _check_ir_validity=False, _skip_dim_order=True
                ),
            ).to_executorch()

        # Success if you use dim_order
        to_edge(
            export(model, inputs, strict=True),
            compile_config=exir.EdgeCompileConfig(
                _check_ir_validity=False, _skip_dim_order=False
            ),
        ).to_executorch()

    def test_emit_multiple_entry_points(self) -> None:
        class SimpleLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
                self.linear2 = torch.nn.Linear(5, 5)

            def forward_relu(self, x: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.relu(self.linear(x))

            def forward_sigmoid(self, x: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.sigmoid(self.linear2(x))

        model = SimpleLinear()
        inputs = (torch.ones(10, 5),)
        with patch_forward(model, model.forward_relu):
            program_relu = to_edge(
                export(model, inputs, strict=True),
                compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
            ).to_executorch()
        with patch_forward(model, model.forward_sigmoid):
            program_sigmoid = to_edge(
                export(model, inputs, strict=True),
                compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
            ).to_executorch()
        exir_input = {
            "forward_relu": program_relu.exported_program(),
            "forward_sigmoid": program_sigmoid.exported_program(),
        }
        merged_program = emit_program(exir_input, False).program
        self.assertEqual(len(merged_program.execution_plan), 2)

        self.assertEqual(
            merged_program.execution_plan[0].name,
            "forward_relu",
        )
        self.assertEqual(
            merged_program.execution_plan[1].name,
            "forward_sigmoid",
        )
        # reserved spot, weight, bias
        self.assertEqual(
            len(program_sigmoid._emitter_output.program.constant_buffer),
            3,
        )
        self.assertEqual(
            len(program_relu._emitter_output.program.constant_buffer),
            3,
        )
        # sum of the entry points minus 1 because we only have one reserved spot still
        self.assertEqual(
            len(merged_program.constant_buffer),
            len(program_sigmoid._emitter_output.program.constant_buffer)
            + len(program_relu._emitter_output.program.constant_buffer)
            - 1,
        )

        self._compare_execution_plans(
            merged_program.execution_plan[0],
            program_relu._emitter_output.program.execution_plan[0],
        )
        self._compare_execution_plans(
            merged_program.execution_plan[1],
            program_sigmoid._emitter_output.program.execution_plan[0],
        )

    def test_emit_weight_deduplication(self) -> None:
        class SimpleLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward_relu(self, x: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.relu(self.linear(x))

            def forward_sigmoid(self, x: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.sigmoid(self.linear(x))

        model = SimpleLinear()
        inputs = (torch.ones(10, 5),)
        with patch_forward(model, model.forward_relu):
            program_relu = to_edge(export(model, inputs, strict=True)).to_executorch()
        with patch_forward(model, model.forward_sigmoid):
            program_sigmoid = to_edge(
                export(model, inputs, strict=True)
            ).to_executorch()
        exir_input = {
            "forward_relu": program_relu.exported_program(),
            "forward_sigmoid": program_sigmoid.exported_program(),
        }
        merged_program = emit_program(exir_input, False).program
        self.assertEqual(len(merged_program.execution_plan), 2)

        # reserved spot, weight, bias
        self.assertEqual(
            len(program_sigmoid._emitter_output.program.constant_buffer),
            3,
        )
        self.assertEqual(
            len(program_relu._emitter_output.program.constant_buffer),
            3,
        )
        # weights are shared between entry points so the merged one should deduplicate everything
        self.assertEqual(len(merged_program.constant_buffer), 3)

        self._compare_execution_plans(
            merged_program.execution_plan[0],
            program_relu._emitter_output.program.execution_plan[0],
        )
        self._compare_execution_plans(
            merged_program.execution_plan[1],
            program_sigmoid._emitter_output.program.execution_plan[0],
        )

    def test_emit_execution_plans_sorted(self) -> None:
        class Simple(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def a(self, x: torch.Tensor) -> torch.Tensor:
                return x

            def b(self, x: torch.Tensor) -> torch.Tensor:
                return x

            def c(self, x: torch.Tensor) -> torch.Tensor:
                return x

        model = Simple()
        inputs = (torch.ones(10, 5),)

        def make_program(
            fn,
            inputs,
        ) -> "ExecutorchProgramManager":
            return to_edge(
                export(WrapperModule(fn), inputs, strict=True)
            ).to_executorch()

        program_a = make_program(model.a, inputs)
        program_b = make_program(model.b, inputs)
        program_c = make_program(model.c, inputs)

        exir_input = {
            "b": program_b.exported_program(),
            "c": program_c.exported_program(),
            "a": program_a.exported_program(),
        }
        merged_program = emit_program(exir_input, False).program
        self.assertEqual(len(merged_program.execution_plan), 3)
        self.assertEqual(merged_program.execution_plan[0].name, "a")
        self.assertEqual(merged_program.execution_plan[1].name, "b")
        self.assertEqual(merged_program.execution_plan[2].name, "c")

        # Create a second program equivalent to the first, but the input is in a different order.
        # python dicts are instertion ordered
        exir_input2 = {
            "a": program_b.exported_program(),
            "b": program_c.exported_program(),
            "c": program_a.exported_program(),
        }
        merged_program2 = emit_program(exir_input2, False).program
        self.assertEqual(
            merged_program2.execution_plan[0], merged_program.execution_plan[0]
        )
        self.assertEqual(
            merged_program2.execution_plan[1], merged_program.execution_plan[1]
        )
        self.assertEqual(
            merged_program2.execution_plan[2], merged_program.execution_plan[2]
        )

    def test_upper_bound_memory_planning_respect_input_constraints(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, k: torch.Tensor) -> torch.Tensor:
                k = torch.cat((k, torch.ones(1, 4)))
                return k

        func = Foo()

        k = torch.rand(2, 4)
        dim0_k = Dim("dim0_k", max=3)
        dynamic_shapes = {"k": {0: dim0_k}}
        captured = export(func, (k,), dynamic_shapes=dynamic_shapes, strict=True)
        edge = to_edge(captured)
        from executorch.exir.passes import MemoryPlanningPass

        config = exir.ExecutorchBackendConfig(
            sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
            memory_planning_pass=MemoryPlanningPass(
                # allow_lifetime_and_storage_overlap: bool = False,
                alloc_graph_input=True,
                alloc_graph_output=False,
            ),
        )

        exe_prog = edge.to_executorch(config)
        program = exe_prog._emitter_output.program
        exir.print_program.pretty_print(exe_prog._emitter_output.program.execution_plan)
        execution_plan = program.execution_plan[0]
        self.check_tensor_buffer_loc(0, execution_plan.values, 0, 1, 0)
        self.check_tensor_buffer_loc(1, execution_plan.values, 0, 1, 48)

    def test_emit_prims(self) -> None:
        tensor_output = torch.rand(1, 4)
        tensor_list_output = [torch.rand(1, 4), torch.rand(1, 4)]

        class Simple(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
                self.x: int = 3
                self.y = 2

            def get_ints(self) -> Tuple[int]:
                return (self.x, self.y)

            def get_str(self) -> str:
                return "foo"

            def get_tensor(self) -> torch.Tensor:
                return tensor_output

            def get_tensor_list(self) -> List[torch.Tensor]:
                return tensor_list_output

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.sigmoid(self.linear(x))

        model = Simple()
        inputs = (torch.ones(10, 5),)
        program = to_edge(export(model, inputs, strict=True)).to_executorch()
        exir_input = {
            "forward": program.exported_program(),
        }
        getters = {}
        getters["get_ints"] = model.get_ints()
        getters["get_str"] = model.get_str()
        getters["get_tensor"] = model.get_tensor()
        getters["get_tensor_list"] = model.get_tensor_list()

        merged_program = emit_program(exir_input, False, getters).program

        self.assertEqual(len(merged_program.execution_plan), 5)

        self.assertEqual(
            merged_program.execution_plan[0].name,
            "forward",
        )
        self.assertEqual(
            merged_program.execution_plan[1].name,
            "get_ints",
        )
        self.assertEqual(
            merged_program.execution_plan[2].name,
            "get_str",
        )
        self.assertEqual(
            merged_program.execution_plan[3].name,
            "get_tensor",
        )
        self.assertEqual(
            merged_program.execution_plan[4].name,
            "get_tensor_list",
        )

        # no instructions in a getter
        self.assertEqual(
            len(merged_program.execution_plan[1].chains[0].instructions),
            0,
        )
        # 2 outputs for the flattened tuple
        self.assertEqual(
            len(merged_program.execution_plan[1].outputs),
            2,
        )
        # outputs are 0 and 1 in the values table
        self.assertEqual(
            merged_program.execution_plan[1].outputs,
            [0, 1],
        )
        # value 0 is 3
        self.assertEqual(
            # pyre-ignore
            merged_program.execution_plan[1].values[0].val.int_val,
            3,
        )
        self.assertEqual(
            # pyre-ignore
            merged_program.execution_plan[1].values[1].val.int_val,
            2,
        )
        self.assertEqual(
            len(merged_program.execution_plan[2].outputs),
            1,
        )
        self.assertEqual(
            # pyre-ignore
            merged_program.execution_plan[2].values[0].val.string_val,
            "foo",
        )
        self.assertEqual(len(merged_program.execution_plan[3].outputs), 1)
        self.assertEqual(len(merged_program.execution_plan[4].outputs), 2)

        merged_program = to_edge(
            export(model, inputs, strict=True), constant_methods=getters
        ).to_executorch()
        executorch_module = _load_for_executorch_from_buffer(merged_program.buffer)
        torch.allclose(executorch_module.run_method("get_tensor", [])[0], tensor_output)
        model_output = executorch_module.run_method("get_tensor_list", [])
        for i in range(len(tensor_list_output)):
            torch.allclose(model_output[i], tensor_list_output[i])

    def test_emit_debug_handle_map(self) -> None:
        mul_model = Mul()
        program_mul = to_edge(
            export(mul_model, mul_model.get_random_inputs(), strict=True)
        ).to_executorch()
        # this triggers the actual emission of the graph
        program_mul._emitter_output.program
        self.assertIsNotNone(program_mul.debug_handle_map)

    def test_final_graph_module_update_debug_handle(self) -> None:
        class SimpleAddMul(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                a = x + 1
                return a * 2

        mul_model = SimpleAddMul()
        program_mul = to_edge(
            export(mul_model, (torch.ones(2, 2),), strict=True)
        ).to_executorch()

        # this triggers the actual emission of the graph
        program = program_mul._emitter_output.program
        node = None
        program.execution_plan[0].chains[0].instructions[0].instr_args.op_index

        # Find the multiplication node in the graph that was emitted.
        for node in program_mul.exported_program().graph.nodes:
            if (
                node.target == torch.ops.aten.mul.out
                or node.target == torch.ops.aten.mul.Scalar_out
            ):
                break
        self.assertIsNotNone(node)

        idx = 0
        # Find the multiplication instruction in the program that was emitted.
        for idx in range(len(program.execution_plan[0].chains[0].instructions)):
            instruction = program.execution_plan[0].chains[0].instructions[idx]
            op_index = instruction.instr_args.op_index
            if "mul" in program.execution_plan[0].operators[op_index].name:
                break

        # The instruction id of the multiplication instruction and the debug handle of the
        # multiplication node in the graph module (which was updated in the emitter to be
        # the same as the instruction id) must be the same.
        self.assertEqual(
            idx,
            node.meta.get("debug_handle"),
        )

    def test_delegate_with_input_list(self) -> None:
        class BackendWithCompilerExample(BackendDetails):
            @staticmethod
            def preprocess(
                edge_program,
                compile_specs,
            ) -> bytes:
                return PreprocessResult(
                    processed_bytes=bytes(str("test"), encoding="utf8"),
                    debug_handle_map=None,
                )

        class TestModel(nn.Module):
            def __init__(self):
                super(TestModel, self).__init__()

            def forward(self, x):
                return torch.cat(x)

        inputs = ([torch.ones(2, 2), torch.ones(2, 2)],)
        model = TestModel()
        edgeir_m = to_edge(export(model, inputs, strict=True))
        lowered_module = to_backend(
            "BackendWithCompilerExample", edgeir_m.exported_program(), []
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_module = lowered_module

            def forward(self, list_a):
                return self.lowered_module(list_a)

        composite_model = CompositeModule()
        exec_prog = to_edge(
            export(composite_model, inputs, strict=True),
        ).to_executorch()
        exec_prog.buffer

    def test_delegate_with_input_tuple(self) -> None:
        class BackendWithCompilerExample(BackendDetails):
            @staticmethod
            def preprocess(
                edge_program,
                compile_specs,
            ) -> bytes:
                return PreprocessResult(
                    processed_bytes=bytes(str("test"), encoding="utf8"),
                    debug_handle_map=None,
                )

        class AddMulModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input):  # a, x, b):
                y = torch.mm(input[0], input[1])
                z = torch.add(y, input[2])
                return z

        model_inputs = ((torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2)),)
        model = AddMulModule()
        edgeir_m = to_edge(export(model, model_inputs, strict=True))
        lowered_module = to_backend(
            "BackendWithCompilerExample", edgeir_m.exported_program(), []
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_module = lowered_module

            def forward(self, list_a):
                return self.lowered_module(list_a)

        composite_model = CompositeModule()
        exec_prog = to_edge(
            export(composite_model, model_inputs, strict=True),
        ).to_executorch()
        exec_prog.buffer

    def test_delegate_mapping(self) -> None:
        debug_handle_map = {1: [1, 2]}

        class BackendWithCompilerExample(BackendDetails):
            @staticmethod
            def preprocess(
                edge_program,
                compile_specs,
            ) -> bytes:
                return PreprocessResult(
                    processed_bytes=bytes(str("test"), encoding="utf8"),
                    debug_handle_map=debug_handle_map,
                )

        class TestModel(nn.Module):
            def __init__(self):
                super(TestModel, self).__init__()

            def forward(self, x, y):
                return torch.add(x, y)

        inputs = (torch.ones(2, 2), torch.ones(2, 2))
        model = TestModel()
        edgeir_m = to_edge(export(model, inputs, strict=True))
        lowered_module = to_backend(
            "BackendWithCompilerExample", edgeir_m.exported_program(), []
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_module = lowered_module

            def forward(self, x, y):
                return self.lowered_module(x, y)

        composite_model = CompositeModule()
        exec_prog = to_edge(
            export(composite_model, inputs, strict=True),
        ).to_executorch()
        # Reading the program triggers the call to emit_program underneath which
        # we need to be done for our test to succeed.
        exec_prog._emitter_output.program
        self.assertIsNotNone(exec_prog.delegate_map)
        self.assertIsNotNone(exec_prog.delegate_map.get("forward"))
        self.assertIsNotNone(exec_prog.delegate_map.get("forward").get(0))
        self.assertEqual(
            exec_prog.delegate_map.get("forward").get(0).get("name"),
            "BackendWithCompilerExample",
        )
        self.assertTrue(
            len(exec_prog.delegate_map.get("forward").get(0).get("delegate_map")) != 0
        )

    def test_emit_weight_view(self) -> None:
        class ModWithWeightViews(nn.Module):
            def __init__(self):
                super(ModWithWeightViews, self).__init__()
                self.W = torch.nn.Parameter(torch.randn(2))
                self.W1 = self.W[:1]
                self.W2 = self.W[1:]

            def forward(self, x):
                return self.W1 + self.W2 + x

        model = ModWithWeightViews()
        # each weight is a view of the same storage
        self.assertEqual(model.W1.nbytes, 4)
        self.assertEqual(model.W1.untyped_storage().nbytes(), 8)
        self.assertEqual(model.W2.nbytes, 4)
        self.assertEqual(model.W2.untyped_storage().nbytes(), 8)
        program = to_edge(export(model, (torch.ones(1),), strict=True)).to_executorch()

        program = program._emitter_output.program
        # each emitted weight is not a view
        self.assertEqual(len(program.constant_buffer[1].storage), 4)
        self.assertEqual(len(program.constant_buffer[2].storage), 4)

    def test_non_persistent_buffer(self) -> None:
        class NonPersistentBuffer(nn.Module):
            def __init__(self):
                super(NonPersistentBuffer, self).__init__()
                self.register_buffer("buf", torch.tensor([1]), persistent=False)

            def forward(self, x):
                return x + self.buf

        model = NonPersistentBuffer()
        program = to_edge(export(model, (torch.ones(1),), strict=True)).to_executorch()
        program = program._emitter_output.program
        # confirm that the buffer was emitted
        self.assertEqual(len(program.constant_buffer), 2)
        self.assertEqual(len(program.constant_buffer[1].storage), 8)

    def test_emit_lifted_tensor_constant(self) -> None:
        class LiftedTensorConstants(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = x * torch.tensor([[4, 3], [1, 2], [5, 6]], dtype=torch.float)
                return x

        model = LiftedTensorConstants()
        # Specify that we want to move non-lifted constants to external file
        et_cfg = ExecutorchBackendConfig(external_constants=True)
        program = to_edge(
            export(model, (torch.ones(3, 2),), strict=True)
        ).to_executorch(et_cfg)
        program = program._emitter_output.program
        exec_plan = program.execution_plan[0]
        # There should only be 1 input to this model.
        self.assertEqual(len(exec_plan.inputs), 1)
        self.assertEqual(len(program.constant_buffer), 2)
        self.assertEqual(len(program.constant_buffer[1].storage), 24)

    def test_emit_lifted_constant(self) -> None:
        class LiftedConstants(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = x + 1
                return x

        model = LiftedConstants()
        # Specify that we want to move non-lifted constants to external file
        et_cfg = ExecutorchBackendConfig(external_constants=True)
        program = to_edge(
            export(model, (torch.ones(3, 2),), strict=True)
        ).to_executorch(et_cfg)

        program = program._emitter_output.program
        exec_plan = program.execution_plan[0]
        # There should only be 1 input to this model.
        self.assertEqual(len(exec_plan.inputs), 1)
        self.assertEqual(len(program.constant_buffer), 2)
        self.assertEqual(len(program.constant_buffer[1].storage), 8)

    def test_mutable_buffers(self) -> None:
        def count_copies(gm: torch.fx.GraphModule) -> int:
            return sum(
                (
                    node.target == torch.ops.aten.copy_
                    or node.target == exir_ops.edge.aten.copy_.default
                )
                for node in gm.graph.nodes
            )

        class MutableStateModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("state", torch.zeros(1))

            def forward(self, x):
                y = x + self.state
                self.state.add_(1)
                return y

        model = to_edge(export(MutableStateModule(), (torch.zeros(1),), strict=True))
        model = model.to_executorch()
        model.dump_executorch_program(True)
        self.assertTrue(
            model.executorch_program.execution_plan[0].values[0].val.allocation_info
            is not None
        )
        executorch_module = _load_for_executorch_from_buffer(model.buffer)
        self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1))
        self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1) + 1)

    def test_mutable_buffers_without_memplanned_inputs(self) -> None:
        def count_copies(gm: torch.fx.GraphModule) -> int:
            return sum(
                (
                    node.target == torch.ops.aten.copy_
                    or node.target == exir_ops.edge.aten.copy_.default
                )
                for node in gm.graph.nodes
            )

        class MutableStateModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("state", torch.zeros(1))

            def forward(self, x):
                y = x + self.state
                self.state.add_(1)
                return y

        model = to_edge(export(MutableStateModule(), (torch.zeros(1),), strict=True))
        model = model.to_executorch(
            config=ExecutorchBackendConfig(
                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
                sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
            )
        )
        model.dump_executorch_program(True)
        self.assertTrue(
            model.executorch_program.execution_plan[0].values[0].val.allocation_info
            is not None
        )
        executorch_module = _load_for_executorch_from_buffer(model.buffer)
        self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1))
        self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1) + 1)

    def test_infinity_in_model(self) -> None:
        class InfinityMaskModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.mask = torch.tensor([[1, 0], [0, 1]], dtype=torch.float32)

            def forward(self, x):
                masked_weights = x.masked_fill(self.mask == 0, float("-inf"))
                return masked_weights

        model = to_edge(export(InfinityMaskModel(), (torch.randn(2, 2),), strict=True))

        # Confirm that we can serialize the model with infinity in it.
        model = model.to_executorch()

        # Assert that the infinity is stored as a string "-inf".
        values = model.executorch_program.execution_plan[0].values
        self.assertEqual(values[5].val, Double(double_val=float("-inf")))

        # Confirm that we can also deserialize the model with infinity in it.
        deserialize = deserialize_pte_binary(model.buffer)
        self.assertEqual(
            deserialize.program.execution_plan,
            model.executorch_program.execution_plan,
        )

    def test_mutate_input_tensor(self) -> None:
        class MutateInputTensorModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x.add_(1)

        model = to_edge(
            export(MutateInputTensorModule(), (torch.zeros(1),), strict=True)
        ).to_executorch(
            config=ExecutorchBackendConfig(
                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False)
            )
        )
        executorch_model = _load_for_executorch_from_buffer(model.buffer)
        input = torch.zeros(1)
        executorch_model(input)
        self.assertEqual(input, torch.ones(1))

    def test_constant_tagged_tensors(self) -> None:
        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x)

        model = to_edge(
            export(LinearModule(), (torch.ones(5, 5),), strict=True)
        ).to_executorch(
            config=ExecutorchBackendConfig(
                external_constants=True,
            )
        )
        emitter_output = model._emitter_output
        # Check that constant_buffer is empty besides the non-constant placeholder 0.
        self.assertEqual(len(emitter_output.program.constant_buffer), 1)
        # Check that constant weights are in the external constant buffer.
        self.assertEqual(len(emitter_output.external_constant_buffer), 2)
        # Setting external_constants=True, saves all constants to the key
        # '_default_external_constant'.
        external_map = emitter_output.external_constant_map[
            "_default_external_constant"
        ]
        self.assertEqual(len(external_map), 2)
        self.assertEqual(external_map["linear.weight"], 0)
        self.assertEqual(external_map["linear.bias"], 1)

    def test_constant_tagged_tensors_custom(self) -> None:
        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x)

        model = to_edge(
            export(LinearModule(), (torch.ones(5, 5),), strict=True)
        ).to_executorch(
            config=ExecutorchBackendConfig(
                external_constants=lambda x: (
                    "linear_weight" if "weight" in x.name else None
                ),
            )
        )
        emitter_output = model._emitter_output
        # constant_buffer contains placeholder and linear bias.
        self.assertEqual(len(emitter_output.program.constant_buffer), 2)
        # external constant buffer contains linear weight.
        self.assertEqual(len(emitter_output.external_constant_buffer), 1)
        # The lambda saves all constants to the key 'linear_weight'.
        external_map = emitter_output.external_constant_map["linear_weight"]
        self.assertEqual(len(external_map), 1)
        self.assertEqual(external_map["linear.weight"], 0)

    def test_constant_tagged_tensor_dedup(self) -> None:
        class ConstantModule(nn.Module):
            def __init__(self):
                super().__init__()
                constant = torch.tensor([1.0, 2.0, 3.0])

                # Register the same value with two different names as persistent buffers
                self.register_buffer("c0", constant.clone(), persistent=True)
                self.register_buffer("c1", constant.clone(), persistent=True)

            def forward(self, x):
                return x + self.c0 + self.c1

        model = to_edge(
            export(ConstantModule(), (torch.ones(1, 3),), strict=True)
        ).to_executorch(
            config=ExecutorchBackendConfig(
                external_constants=True,
            )
        )
        emitter_output = model._emitter_output
        # constant_buffer is empty besides the non-constant placeholder 0.
        self.assertEqual(len(emitter_output.program.constant_buffer), 1)
        # only one item in the external constant buffer.
        self.assertEqual(len(emitter_output.external_constant_buffer), 1)
        # Setting external_constants=True, saves all constants to the key
        # '_default_external_constant'.
        external_map = emitter_output.external_constant_map[
            "_default_external_constant"
        ]
        self.assertEqual(len(external_map), 2)
        self.assertEqual(external_map["c0"], 0)
        self.assertEqual(external_map["c1"], 0)

    def test_constant_tagged_tensor_dedup_2(self) -> None:
        class ConstantModule(nn.Module):
            def __init__(self):
                super().__init__()
                constant0_4 = torch.tensor([1.0, 2.0, 3.0])
                constant4_5 = torch.tensor([2.0, 3.0, 4.0])

                # Register the same value with two different names as persistent buffers
                self.register_buffer("c0", constant0_4.clone(), persistent=True)
                self.register_buffer("c1", constant0_4.clone(), persistent=True)
                self.register_buffer("c2", constant0_4.clone(), persistent=True)
                self.register_buffer("c3", constant0_4.clone(), persistent=True)
                self.register_buffer("c4", constant4_5.clone(), persistent=True)
                self.register_buffer("c5", constant4_5.clone(), persistent=True)

            def forward(self, x):
                return x + self.c0 + self.c1 + self.c2 + self.c3 + self.c4 + self.c5

        model = to_edge(
            export(ConstantModule(), (torch.ones(1, 3),), strict=True)
        ).to_executorch(
            config=ExecutorchBackendConfig(
                external_constants=True,
            )
        )
        emitter_output = model._emitter_output
        # constant_buffer is empty besides the non-constant placeholder 0.
        self.assertEqual(len(emitter_output.program.constant_buffer), 1)
        # Two items in the external constant buffer.
        self.assertEqual(len(emitter_output.external_constant_buffer), 2)
        # Setting external_constants=True, saves all constants to the key
        # '_default_external_constant'.
        external_map = emitter_output.external_constant_map[
            "_default_external_constant"
        ]
        self.assertEqual(len(external_map), 6)
        self.assertEqual(external_map["c0"], 0)
        self.assertEqual(external_map["c1"], 0)
        self.assertEqual(external_map["c2"], 0)
        self.assertEqual(external_map["c3"], 0)
        self.assertEqual(external_map["c4"], 1)
        self.assertEqual(external_map["c5"], 1)

    def test_delegate_deduplicate(self) -> None:
        class SharedModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                return self.linear(x)

        class Module1(torch.nn.Module):
            def __init__(self, shared_module):
                super().__init__()
                self.shared_module = shared_module

            def forward(self, x):
                return self.shared_module(x)

        class Module2(torch.nn.Module):
            def __init__(self, shared_module):
                super().__init__()
                self.shared_module = shared_module

            def forward(self, x):
                return self.shared_module(x)

        shared_module = SharedModule()
        module_1 = Module1(shared_module)
        module_2 = Module2(shared_module)
        example_inputs = (torch.randn(2, 2),)
        module_1(*example_inputs)
        module_2(*example_inputs)

        ep1 = export(module_1, example_inputs, strict=True)
        ep2 = export(module_2, example_inputs, strict=True)

        edge_program_manager = exir.to_edge(
            {"forward1": ep1, "forward2": ep2},
            compile_config=exir.EdgeCompileConfig(
                _check_ir_validity=False, _use_edge_ops=True
            ),
        )

        edge_program_manager = edge_program_manager.to_backend(
            ExecutorBackendPartitioner()
        ).to_executorch()

        # Check that there is only one delegate because two methods are exactly the same
        self.assertEqual(
            len(edge_program_manager.executorch_program.backend_delegate_data), 1
        )

    def test_delegate_deduplicate_with_different_compile_specs(self) -> None:
        class LowerableSubModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.sin(x)

        lowered = LowerableSubModel()
        example_input = (torch.ones(1),)

        lowered_edge = to_edge(export(lowered, example_input))

        from executorch.exir.backend.compile_spec_schema import CompileSpec

        compile_specs1 = [CompileSpec("config", b"fast")]
        compile_specs2 = [CompileSpec("config", b"small")]
        lowered_module1 = to_backend(
            "BackendWithCompilerDemo", lowered_edge.exported_program(), compile_specs1
        )
        lowered_module2 = to_backend(
            "BackendWithCompilerDemo", lowered_edge.exported_program(), compile_specs2
        )

        class CompositeModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowerable1 = lowered_module1
                self.lowerable2 = lowered_module2

            def forward(self, x):
                a = self.lowerable1(x)
                b = self.lowerable2(a)
                return a, b

        composite_model = CompositeModel()
        model_inputs = (torch.ones(1),)
        edge_prog = to_edge(export(composite_model, model_inputs)).to_executorch()

        exported_program = edge_prog.exported_program()
        program = emit_program({"method1": exported_program}, False).program
        self.assertEqual(len(program.execution_plan), 1)

        plan = program.execution_plan[0]
        # Two delegates that point to the same blob.
        self.assertEqual(len(plan.delegates), 2)
        self.assertEqual(plan.delegates[0].processed.index, 0)
        self.assertEqual(plan.delegates[1].processed.index, 0)
        # Compile specs are different.
        self.assertEqual(plan.delegates[0].compile_specs, compile_specs1)
        self.assertEqual(plan.delegates[1].compile_specs, compile_specs2)
        # Only one delegate blob in the backend_delegate_data.
        self.assertEqual(len(program.backend_delegate_data), 1)

    def test_constant_tagged_mutable_tensors(self) -> None:
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(2, 2)

            def forward(self, x):
                return self.linear(x)

        # On device training requires the loss to be embedded in the model (and be the first output).
        # We wrap the original model here and add the loss calculation. This will be the model we export.
        class TrainingNet(nn.Module):
            def __init__(self, net):
                super().__init__()
                self.net = net
                self.loss = nn.CrossEntropyLoss()

            def forward(self, input, label):
                pred = self.net(input)
                return self.loss(pred, label), pred.detach().argmax(dim=1)

        net = TrainingNet(Net())

        # Captures the forward graph. The graph will look similar to the model definition now.
        ep = export(
            net, (torch.randn(1, 2), torch.ones(1, dtype=torch.int64)), strict=True
        )
        # Captures the backward graph. The exported_program now contains the joint forward and backward graph.
        ep = _export_forward_backward(ep)
        # Lower the graph to edge dialect.
        ep = to_edge(ep)
        # Lower the graph to executorch.
        ep = ep.to_executorch(
            config=ExecutorchBackendConfig(external_mutable_weights=True)
        )

        emitter_output = ep._emitter_output
        # Check that constant_buffer is empty besides the non-constant placeholder 0.
        self.assertEqual(len(emitter_output.program.constant_buffer), 1)
        # Check that constant weights are in the external constant buffer.
        self.assertEqual(len(emitter_output.external_constant_buffer), 2)
        # Setting external_mutable_weights=True, saves all constants with an associated gradient to the key
        # '_default_external_constant'.
        external_map = emitter_output.external_constant_map[
            "_default_external_constant"
        ]
        self.assertEqual(external_map["net.linear.weight"], 0)
        self.assertEqual(external_map["net.linear.bias"], 1)

    def test_emit_mutable_buffer_names(self) -> None:
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(2, 2)
                self.register_buffer("buffer", torch.zeros(1, 2))

            def forward(self, x):
                self.buffer.add_(1)
                return self.linear(x) + self.buffer

        net = Net()

        ep = export(net, (torch.randn(1, 2),), strict=True)
        # Lower the graph to edge dialect.
        ep = to_edge(ep)
        # Lower the graph to executorch.
        ep = ep.to_executorch(
            config=ExecutorchBackendConfig(
                emit_mutable_buffer_names=True,
                memory_planning_pass=MemoryPlanningPass(alloc_mutable_buffers=False),
            )
        )
        for val in ep.executorch_program.execution_plan[0].values:
            if isinstance(val, Tensor) and val.extra_tensor_info:
                self.assertEqual(val.extra_tensor_info.fully_qualified_name, "buffer")
                self.assertEqual(val.allocation_info, None)

    def test_emit_mutable_buffer_names_fails(self) -> None:
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(2, 2)
                self.register_buffer("buffer", torch.zeros(1, 2))

            def forward(self, x):
                self.buffer.add_(1)
                return self.linear(x) + self.buffer

        net = Net()

        ep = export(net, (torch.randn(1, 2),), strict=True)
        # Lower the graph to edge dialect.
        ep = to_edge(ep)
        # Lower the graph to executorch.
        # Must emit mutable buffer names if we don't allocate mutable buffers
        with self.assertRaises(InternalError):
            ep.to_executorch(
                config=ExecutorchBackendConfig(
                    emit_mutable_buffer_names=False,
                    memory_planning_pass=MemoryPlanningPass(
                        alloc_mutable_buffers=False
                    ),
                )
            )

    def test_emit_sym_min_max(self) -> None:
        class SymMaxModel(nn.Module):
            def __init__(self, test_min=False):
                super().__init__()
                self.test_min = test_min

            def forward(self, x):
                # Get size of 0th dimension - this creates sym_size op
                batch_size = x.shape[0]
                # Compute max of batch_size and 10 - this should create sym_max op
                if self.test_min:
                    out_size = min(batch_size, 10)
                else:
                    out_size = max(batch_size, 10)
                # Create a 1D tensor of zeros with the computed size
                result = torch.zeros(out_size, dtype=x.dtype, device=x.device)
                return result

        for validate_min in [True, False]:
            model = SymMaxModel(test_min=validate_min)
            test_inputs = [
                torch.randn(5, 3),  # should output zeros(10) for max zeros(5) for min
                torch.randn(15, 3),  # should output zeros(15) for max zeros(10) for min
                torch.randn(10, 3),  # should output zeros(10) for max zeros(10) for min
            ]
            model.eval()
            reference_outputs = []
            with torch.no_grad():
                for _, inp in enumerate(test_inputs):
                    output = model(inp)
                    reference_outputs.append(output)

            batch_dim = Dim("batch", min=1, max=20)
            dynamic_shapes = {"x": {0: batch_dim}}  # 0th dimension is dynamic
            exported_program = torch.export.export(
                model, (test_inputs[0],), dynamic_shapes=dynamic_shapes
            )
            edge_program = to_edge(
                exported_program,
                compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
            )
            et_program = edge_program.to_executorch()
            program_buffer = et_program.buffer
            et_module = _load_for_executorch_from_buffer(program_buffer)
            for _, (inp, expected) in enumerate(zip(test_inputs, reference_outputs)):
                # Execute with ExecuTorch
                et_output = et_module.forward([inp])
                et_result = et_output[0]  # Get first output
                # Compare results
                self.assertTrue(expected.shape == et_result.shape)
                self.assertTrue(torch.allclose(expected, et_result))
